import json
import numpy as np
import matplotlib.pyplot as plt

# Plot accuracy-vs-distortion curves for the train and test splits side by side
# and save the figure to ./Distortion/distortion_curves.png.
#
# Each input JSON file is expected to hold a mapping of curve name -> sequence
# of y values (plotted as "Accuracy"), one value per distortion level.
# NOTE(review): each sequence presumably has n_samples entries so it lines up
# with `x`; matplotlib raises ValueError otherwise.
# NOTE: the input filenames really are spelled "distorsion" on disk; do not
# "fix" them without renaming the files.


def _load_curves(path):
    """Load and return the JSON mapping of curve name -> y values at *path*.

    A context manager guarantees the file is closed even if parsing fails.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def _plot_panel(ax, curves, x_values, title, xlabel):
    """Draw one labelled line per entry of *curves* on *ax* and decorate it.

    Args:
        ax: matplotlib Axes to draw on.
        curves: mapping of legend label -> sequence of y values.
        x_values: shared x coordinates for every curve.
        title: panel title.
        xlabel: x-axis label.
    """
    for name, values in curves.items():
        ax.plot(x_values, values, label=name)
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel(xlabel)


auc_vals_train = _load_curves("./Distortion/distorsion_train.json")
auc_vals_test = _load_curves("./Distortion/distorsion_test.json")

n_samples = 50
granularity = 100
jump_size = 2
# True: distortion axis is a Gaussian noise sigma; False: pixel removal ratio.
additive_distortion = False
# Distortion levels 0, jump_size, 2*jump_size, ... scaled down by granularity
# (0.00, 0.02, ..., 0.98 with the defaults above).
x = np.arange(n_samples) * jump_size / granularity
y_ticks = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
x_label = "Gaussian Sigma" if additive_distortion else "Pixel Removal Ratio"

fig, subplots = plt.subplots(1, 2, sharey=True, figsize=(100/9, 4))
fig.subplots_adjust(wspace=0.02)
fig.suptitle("Distortion Curves")
# The y axis is shared, so configuring it on the left panel applies to both.
subplots[0].set_yticks(y_ticks)
subplots[0].set_ylim([0, 1.1])

_plot_panel(subplots[0], auc_vals_train, x, "Train", x_label)
subplots[0].set_ylabel("Accuracy")
_plot_panel(subplots[1], auc_vals_test, x, "Test", x_label)

fig.savefig("./Distortion/distortion_curves.png", bbox_inches="tight")
